import pandas as pd
import numpy as np
from tqdm import tqdm
import os
from collections import defaultdict
from itertools import product

import utils
from utils import LegislatorsDataset

from sklearn.model_selection import cross_val_score, GroupKFold
from sklearn.ensemble import RandomForestClassifier

from utils import permute_data, ConditionalMajorityVote, make_indicator, \
                  party_indicator, load_csv, to_csv, predict_and_score


# ---- run configuration ------------------------------------------------------

# Seed numpy's global RNG so the row shuffle done later with
# np.random.permutation is repeatable.
np.random.seed(35897150)

# Row filters applied to the vote table.
keep_abstain = False    # keep 'Not Voting (Abstention)' rows when True
only_bills = True       # keep only the selected passage-type vote questions
only_house = False      # keep only rows whose chamber is 'House' when True

# `congresses_modeled` is used to build input file names; `congresses` is the
# subset of congresses whose votes are kept for evaluation.
congresses_modeled = list(range(110, 118))
congresses = list(range(113, 118))

merge_third_parties = True    # relabel every party other than D / R as 'I'
name_field = 'Name_merged'    # column / attribute used as the legislator key

# Embedding sources and how many leading components of each are kept.
embed_field_name = 'PCAEmbed'
bill_embed_field_name = 'embedding_LongFormer_PCA'
n_components_bill_embed = 25
n_components = 25
if not embed_field_name.startswith('PCA'):
    # a None slice bound keeps every column of the legislator embedding
    n_components = None

cv_folds = 10    # upper bound on the number of cross-validation folds
n_jobs = 3       # parallel jobs handed to the evaluation helper

# Where inputs are read from and where results are written.
analysis_dir = '.'
results_dirname = 'results-vote_prediction'

try:
    os.mkdir(analysis_dir)
except FileExistsError:
    # already there -- nothing to do
    pass

print('loading in bill models and vote data')
# Reuse `df_votes_` when the name is already bound (presumably left over from
# an earlier run in the same interactive session); otherwise read the
# vote-by-member table from disk. Only NameError is caught, so real failures
# (and Ctrl-C) are no longer silently swallowed.
try:
    print(df_votes_.head())
except NameError:
    df_votes_ = load_csv(os.path.join(
        analysis_dir,
        '%ith_to_%ith_congress-vote_by_member_procced.csv'
        % (congresses_modeled[0], congresses_modeled[-1])))


# Reuse `df_members` when the name is already bound; otherwise load the
# combined member table, building it from the per-congress files if the
# combined file does not exist yet.
#
# Only the member table is handled here: `df_votes` is derived from
# `df_votes_` further down, so no vote file is read at this point.
try:
    print(df_members.head())
except NameError:
    members_path = os.path.join(
        analysis_dir,
        '%ith_to_%ith_congress-members-procced.csv'
        % (congresses_modeled[0], congresses_modeled[-1]))
    try:
        df_members = load_csv(members_path)
    except FileNotFoundError:
        print('did not find data... compiling...')
        per_congress_members = []
        for congress in congresses_modeled:
            print('\t%ith congress' % congress)
            per_congress_members.append(load_csv(os.path.join(
                analysis_dir,
                '%ith_congress-members-procced.csv' % congress)))

        df_members = pd.concat(per_congress_members)
        # cache the combined table so the next run can load it directly
        to_csv(df_members, members_path)


# Work on a copy so the loaded table `df_votes_` itself is left unmodified.
df_votes = df_votes_.copy()

if not merge_third_parties:
    df_votes['legislator_party'] = np.array(df_votes.party)
else:
    #relabel parties: the two major parties get a one-letter code and every
    #other value falls back to 'I'
    party2label = defaultdict(lambda: 'I',
                              {'Democrat': 'D', 'Republican': 'R'})
    df_votes['legislator_party'] = df_votes.party.apply(
        lambda party_name: party2label[party_name])
    df_votes['sponsor_party'] = df_votes.sponsor_party.apply(
        lambda code: code if code in ('D', 'R') else 'I')

# keep only the congresses selected for evaluation
in_selected_congress = np.isin(df_votes.congress, congresses)
df_votes = df_votes.loc[in_selected_congress]

#initialize a Legislators dataset
Legislators = utils.load(os.path.join(analysis_dir, 'Legislators.lds'))

# drop votes cast by anyone not listed in the Legislators dataset
known_legislators = getattr(Legislators, name_field)
df_votes = df_votes.loc[np.isin(df_votes[name_field], known_legislators)]

#coded votes: vote_NVmerged merges Present and Abstain
#             cast_description has Yea, Nay, Present, and Abstain

if not keep_abstain:
    is_abstention = df_votes.cast_description == 'Not Voting (Abstention)'
    df_votes = df_votes[~is_abstention]

if only_bills:
    # Vote questions that are kept. Question types listed in the script but
    # left disabled: the "Suspend the Rules and Pass" motions, "Motion to
    # Recommit with Instructions", "On the Motion", cloture motions,
    # amendments, "Agreeing to the Resolution, as Amended", veto overrides,
    # "Consideration of the Joint Resolution" and nominations.
    kept_vote_questions = [
        'On Passage',                       # House
        'On Passage of the Bill',           # Senate
        'On Agreeing to the Resolution',    # House
        'On the Resolution',                # Senate
        'On the Joint Resolution',          # Senate
    ]
    df_votes = df_votes[np.isin(df_votes.vote_question, kept_vote_questions)]

# rows without a bill id or a bill embedding cannot be used
required_columns = ['bill_id', bill_embed_field_name]
df_votes = df_votes.dropna(subset=required_columns)

#permute the dataset to ensure folds aren't correlated with session
shuffled_positions = np.random.permutation(len(df_votes))
df_votes = df_votes.iloc[shuffled_positions]

#set up pipeline for turning vote records into ordered class labels (int)
votecodes = dict(Nay=0, Present=1, Yea=2)
if keep_abstain:
    # abstentions share the middle code with 'Present'
    votecodes['Not Voting (Abstention)'] = 1


def _lookup_votecode(vote):
    # raises KeyError for any description missing from `votecodes`
    return votecodes[vote]


# element-wise version: array of vote descriptions -> array of codes
vote2ordinal__ = np.vectorize(_lookup_votecode)
def vote2ordinal_(y):
    """Close the gaps in an integer label array so its values are consecutive.

    The smallest label is kept as the base; the distinct values of ``y`` are
    mapped, in sorted order, to ``base, base + 1, base + 2, ...``.  An array
    whose labels are already consecutive is returned unchanged.

    Gaps of any width are closed in a single pass (ranking each entry among
    the distinct values), rather than decrementing while scanning a range
    computed before the values started moving.

    Args:
        y: integer array of class codes. It is modified in place.

    Returns:
        The same object ``y``, now holding the compacted labels.
    """
    if y.size == 0:
        # nothing to compact; min()/unique ranking is meaningless here
        return y
    levels = np.unique(y)  # sorted distinct labels
    # rank of every entry among the distinct labels, shifted back to the
    # original minimum so the base label is preserved
    y[:] = levels[0] + np.searchsorted(levels, y)
    return y

def vote2ordinal(y):
    """Turn vote descriptions into integer codes, then compact the codes.

    Applies `vote2ordinal__` (description -> code) followed by
    `vote2ordinal_` (gap removal) and returns the result.
    """
    return vote2ordinal_(vote2ordinal__(y))


# Containers for per-roll-call scores. They are only initialised here; the
# visible part of the script does not fill them.
scores_dict = defaultdict(lambda: defaultdict(list))
scores_dict_ = {
    column: []
    for column in ('model', 'chamber', 'roll_call', 'policy_area',
                   'bill_description', 'score')
}

scores_dict_bysubject = {
    column: []
    for column in ('model', 'chamber', 'roll_call', 'crs_subject', 'score')
}

if only_house:
    df_votes = df_votes.loc[df_votes['chamber'] == 'House']


#recode class target to int
y = vote2ordinal(df_votes['cast_description'])

#organize the data for the task
# Every array built here has one row per row of `df_votes`, in the same order.

print('gathering legislator embeddings')
X_legislator = np.array([getattr(Legislators[name], embed_field_name)
                         for name in tqdm(df_votes[name_field])])
X_legislator_state = np.array([getattr(Legislators[name], 'state_abbrev')
                               for name in tqdm(df_votes[name_field])])

# keep the raw state abbreviation on the frame before it is indicator-encoded
df_votes['state'] = X_legislator_state
# make_indicator comes from utils; presumably a one-hot encoding -- TODO confirm
X_legislator_state = make_indicator(X_legislator_state)

X_legislator_party = make_indicator(np.array(df_votes.legislator_party))
if 'PCA' in embed_field_name or 'RAND' in embed_field_name:
    # keep the leading components only (all of them when n_components is None)
    X_legislator = X_legislator[:, :n_components]

X_billtext = np.array(list(getattr(df_votes, bill_embed_field_name)))
if bill_embed_field_name.endswith('PCA'):
    X_billtext = X_billtext[:, :n_components_bill_embed]

X_sponsor_party = make_indicator(np.array(df_votes.sponsor_party))
# computed once; a second identical computation of X_chamber was removed
X_chamber = make_indicator(np.array(df_votes.chamber))

X_bills = np.array(df_votes.bill_id) #for cross-val by bill
X_names = np.array(df_votes[name_field]) #for cross-val by legislator
X_congress = np.array(df_votes.congress) #for cross-val by congress

# NOTE(review): this value is replaced below before anything visible reads it.
randstate = 420

#permute the data
# NOTE(review): none of the underscore-suffixed outputs are read again in the
# visible part of this script; the models are fit on the arrays listed in
# `data_` themselves -- confirm whether this call is still needed.
data_ = [
    X_legislator, y, X_billtext,
    X_legislator_party, X_sponsor_party,
    X_bills, X_names, X_congress,
    X_legislator_state,
]
permuted_ = permute_data(*data_)
(X_legislator_, y_, X_billtext_,
 X_legislator_party_, X_sponsor_party_,
 X_bills_, X_names_, X_congress_,
 X_legislator_state_) = permuted_


# seed handed to the classifier in the evaluation loop
randstate = 1890547109

#### TEST THE SPECIFICATIONS IN `varspecs` ###########################

# Feature blocks, keyed by the tokens that make up a spec string.
vars_ = dict(
    LEGFIN=X_legislator,
    BILLVEC=X_billtext,
    LEGPRTY=X_legislator_party,
    SPNPRTY=X_sponsor_party,
    CHAMBER=X_chamber,
    LEGSTATE=X_legislator_state,
)

# Each spec is an underscore-joined list of `vars_` keys; the named blocks are
# stacked side by side to form that model's feature matrix.
varspecs = [
    'LEGFIN_BILLVEC_LEGPRTY_SPNPRTY',
    'LEGFIN_BILLVEC_SPNPRTY',
    'LEGPRTY_BILLVEC_SPNPRTY',
    'LEGPRTY_SPNPRTY',
    'LEGPRTY_BILLVEC',
    'LEGSTATE_BILLVEC',
    'LEGSTATE_BILLVEC_SPNPRTY',
    'LEGPRTY_LEGSTATE_BILLVEC',
    'LEGPRTY_LEGSTATE_SPNPRTY',
    'LEGFIN_SPNPRTY',
    'BILLVEC',
    'LEGPRTY',
    'LEGFIN_BILLVEC',
    'LEGFIN_LEGSTATE_BILLVEC_SPNPRTY',
    'LEGPRTY_LEGSTATE_BILLVEC_SPNPRTY',
    'LEGFIN_LEGPRTY_LEGSTATE_BILLVEC_SPNPRTY',
    'LEGFIN_LEGSTATE_SPNPRTY',
]

# Grouping variable for each cross-validation regime.
cv_groups = dict(bill=X_bills, congress=X_congress, legislator=X_names)

# Per-vote metadata that is written out next to every model's predictions.
result_columns = [
    'chamber', 'rollnumber', 'bill_id', 'Name',
    'state', 'crs_policy_area', 'congress',
    'vote_result', 'summary_short', 'sponsor_party',
    'yea_count', 'nay_count', 'vote_question',
]
df_vote_results = df_votes[result_columns].copy()

try:
    os.mkdir(os.path.join(analysis_dir, results_dirname))
except FileExistsError:
    # results directory is already present
    pass

#iterate across the evaluations. this is the cross product of
#	cv regimes (by-bill, by-congress, by-legislator) and
#	the model specifications in `varspecs`
# A (cv regime, spec) pair whose results CSV already exists is skipped, so a
# previously interrupted run picks up where it stopped.
import time
results_files = []
results_dir = os.path.join(analysis_dir, results_dirname)
for (cv_group, X_groups), varspec in product(cv_groups.items(), varspecs):
    out_fname_ = 'results-%s_cv-%s.csv' % (cv_group, varspec)
    out_fname = os.path.join(results_dir, out_fname_)
    if os.path.exists(out_fname):
        continue

    print('evaluating %s with by-%s cv' % (varspec, cv_group))
    # GroupKFold cannot use more folds than there are distinct groups
    cv_splitter = GroupKFold(n_splits=min(len(np.unique(X_groups)), cv_folds))
    # stack the feature blocks named in the spec side by side, in spec order
    X = np.hstack([vars_[var_] for var_ in varspec.split('_')])
    df_vote_results_ = df_vote_results.copy()
    clf = RandomForestClassifier(n_estimators=200, random_state=randstate)
    # feature matrices wider than 125 columns are evaluated with one job
    # (presumably to bound memory use -- TODO confirm)
    n_jobs_ = 1 if X.shape[-1] > 125 else n_jobs

    # Wall-clock timer: time.process_time() counts only this process's CPU
    # time, which misses time spent waiting and any work done in separate
    # worker processes.
    start_time = time.perf_counter()
    predictions, score = predict_and_score(clf, X, y, groups=X_groups,
                                           cv=cv_splitter, n_jobs=n_jobs_)
    print('time: ', time.perf_counter() - start_time)
    print(cv_group, varspec, score)

    df_vote_results_['cv_groups'] = cv_group
    df_vote_results_['model'] = varspec
    df_vote_results_['prediction'] = predictions
    # 1 where the predicted class equals the true class, else 0
    df_vote_results_['accuracy'] = np.array(predictions == y, dtype=np.int8)

    df_vote_results_.to_csv(out_fname, index=False)

    print('results saved to %s' % out_fname)
    results_files.append(out_fname)

pd.set_option('display.max_rows', 500)

# Collect every per-model results file in the results directory, including
# those written by earlier runs. The lowercase 'results-' prefix does not
# match the RESULTS_TABLE*.csv summaries written at the end of this section.
results_dir = os.path.join(analysis_dir, results_dirname)
results_files = [os.path.join(results_dir, fname)
                 for fname in sorted(os.listdir(results_dir))
                 if fname.startswith('results-')]

if not results_files:
    raise FileNotFoundError('no results-*.csv files found in %s' % results_dir)

# Build the combined table from the result files only. `df_vote_results`
# previously held the per-vote metadata without the 'cv_groups' / 'model' /
# 'accuracy' columns; those rows are not results and are not carried over.
# A single concat also avoids re-copying the growing frame once per file.
df_vote_results = pd.concat([pd.read_csv(fname) for fname in results_files],
                            ignore_index=True)

# mean accuracy per (cv regime, model)
df_result_summary = df_vote_results[
    ['cv_groups', 'model', 'accuracy']
].groupby(['cv_groups', 'model']).mean()

# mean accuracy per (cv regime, chamber, model)
df_result_summary_by_chamber = df_vote_results[
    ['cv_groups', 'chamber', 'model', 'accuracy']
].groupby(['cv_groups', 'chamber', 'model']).mean()


# flatten the group keys back into ordinary columns and write both tables
acc = df_result_summary.reset_index()
chamberacc = df_result_summary_by_chamber.reset_index()

acc.to_csv(os.path.join(analysis_dir, results_dirname, 'RESULTS_TABLE.csv'))
chamberacc.to_csv(os.path.join(analysis_dir, results_dirname, 'RESULTS_TABLE-BY_CHAMBER.csv'))


